import scala.collection.mutable.ListBuffer
import scala.math.{pow, sqrt}

/**
  * Created by ariellemoro on 04.12.16.
  */

/**
  * Small k-d tree demo: builds a tree from integer points, prints it and runs
  * a nearest-neighbour query against it.
  *
  * Relies on a `Node` type declared outside this file; only its `xy`,
  * `leftNode` and `rightNode` members are used here, with `null` marking an
  * absent child.
  */
object KDTreeAlgo {

  /**
    * Builds a tree from six sample points, prints it, then looks up the
    * nearest neighbour of (4,8).
    *
    * @param args command-line arguments (unused)
    */
  def main(args: Array[String]): Unit = {
    // Point list: (2,3), (5,4), (9,6), (4,7), (8,1), (7,2)
    val points = ListBuffer(
      ListBuffer(2, 3),
      ListBuffer(5, 4),
      ListBuffer(9, 6),
      ListBuffer(4, 7),
      ListBuffer(8, 1),
      ListBuffer(7, 2)
    )
    // Print the KDTree
    val firstNode = buildKDTree(points, 0)
    println("KDTree =>")
    displayKDTree(firstNode, "first", 0)
    // Search test
    val searchPoint = ListBuffer(4, 8)
    val nearestPoint = nearestNeighbourSearch(searchPoint, firstNode)
    println(s"The nearest point on $searchPoint is $nearestPoint")
  }

  /**
    * Recursively builds a k-d tree. The splitting axis cycles through the
    * point dimensions with the depth (`depth % k`).
    *
    * The points are split by position around the median, so points sharing
    * the median's coordinate on the splitting axis are kept (they may end up
    * on either side) instead of being dropped.
    *
    * @param points points to index; all are expected to have the same
    *               dimension as the first one. The buffer is not modified.
    * @param depth  depth of the node being built (0 for the root)
    * @return the root of the (sub)tree, or `null` when `points` is empty
    */
  def buildKDTree(points: ListBuffer[ListBuffer[Int]], depth: Int): Node = {
    if (points.isEmpty) {
      null
    } else {
      val k = points.head.length
      val axis = depth % k
      // sortBy is stable and returns a new buffer, leaving the caller's untouched.
      val sortedPoints = points.sortBy(p => p(axis))
      val median = sortedPoints.length / 2
      // `rest` starts with the median point; everything after it goes right.
      val (leftPoints, rest) = sortedPoints.splitAt(median)
      Node(rest.head, buildKDTree(leftPoints, depth + 1), buildKDTree(rest.tail, depth + 1))
    }
  }

  /**
    * Finds the node whose point is closest (Euclidean distance) to
    * `searchPoint` in the tree rooted at `currentNode`.
    *
    * The near side of each splitting plane is explored first; the far side is
    * only visited when the plane is closer than the best distance found so
    * far, so a closer point in the other subtree is never missed.
    *
    * @param searchPoint query point; must have at least as many coordinates
    *                    as the points stored in the tree
    * @param currentNode root of the (sub)tree to search; may be `null`
    * @param depth       depth of `currentNode` in the tree it was built in
    *                    (0 for the root); it determines the splitting axis
    * @return the nearest node, or `null` when the tree is empty
    */
  def nearestNeighbourSearch(searchPoint: ListBuffer[Int], currentNode: Node, depth: Int = 0): Node = {
    var best: Node = null
    var bestDistance = Double.PositiveInfinity

    def visit(node: Node, level: Int): Unit = {
      if (node != null) {
        val distance = compareTwoPoints(searchPoint, node.xy)
        if (distance < bestDistance) {
          bestDistance = distance
          best = node
        }
        val axis = level % node.xy.length
        // Signed distance from the query to this node's splitting plane.
        val offset = searchPoint(axis).toDouble - node.xy(axis)
        val (near, far) =
          if (offset < 0) (node.leftNode, node.rightNode)
          else (node.rightNode, node.leftNode)
        visit(near, level + 1)
        // Any point on the far side is at least |offset| away, so it can only
        // improve on the current best when the plane itself is closer.
        if (math.abs(offset) < bestDistance) {
          visit(far, level + 1)
        }
      }
    }

    visit(currentNode, depth)
    best
  }

  /**
    * Euclidean distance between two points, computed over every coordinate
    * the two points have in common.
    *
    * @param searchPoint first point
    * @param point       second point
    * @return the distance as a `Double` (0.0 when either point is empty)
    */
  def compareTwoPoints(searchPoint: ListBuffer[Int], point: ListBuffer[Int]): Double = {
    // Subtract as Double so large coordinates cannot overflow Int.
    val squaredDistance = searchPoint.iterator
      .zip(point.iterator)
      .map { case (a, b) => pow(a.toDouble - b, 2) }
      .sum
    sqrt(squaredDistance)
  }

  /**
    * Prints the tree in pre-order, one line per node, tagged with its depth
    * and the side ("left"/"right") it hangs from. A `null` node prints nothing.
    *
    * @param node      root of the (sub)tree to print; may be `null`
    * @param direction label printed for this node
    * @param depth     depth printed for this node
    */
  def displayKDTree(node: Node, direction: String, depth: Int): Unit = {
    if (node != null) {
      println(s"=> ($depth - $direction) $node")
      displayKDTree(node.leftNode, "left", depth + 1)
      displayKDTree(node.rightNode, "right", depth + 1)
    }
  }

}
